When a tensorflow.function is called, a ConcreteFunction is created behind the scenes every time a new value is passed as an argument. This is not the case for Python global variables, closure variables, or nonlocal variables: a change to their value does not cause a new ConcreteFunction to be created.
This means the state and the result of the tensorflow.function may not be what is expected.
import tensorflow as tf

@tf.function
def addition():
    return 1 + foo

foo = 4
addition()  # tf.Tensor(5, shape=(), dtype=int32): expected result on the first call
foo = 10
addition()  # tf.Tensor(5, shape=(), dtype=int32): unexpected result of 5 instead of 11
As the example above shows, the second call to addition returns the same result as the first. Between the two calls, the arguments passed to the function did not change (addition takes none), so a single ConcreteFunction is created during the first call, with the value of foo fixed at 4.
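The single trace can be observed with a Python side effect, which only runs while TensorFlow traces the function. The following is a minimal sketch; the trace_log list is a hypothetical helper introduced for illustration and is not part of the rule.

```python
import tensorflow as tf

# Hypothetical helper: records the value of foo seen at each trace.
trace_log = []

@tf.function
def addition():
    # Python side effect: executed only while the function is being traced.
    trace_log.append(foo)
    return 1 + foo

foo = 4
first = addition()
foo = 10
second = addition()  # reuses the ConcreteFunction traced when foo was 4

print(trace_log)                # one entry only: the function was traced once
print(int(first), int(second))  # both calls return the same value
```

The list holds a single entry because changing foo does not trigger a new trace.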
This is why it is good practice not to use or mutate global or nonlocal variables inside a tensorflow.function.
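One possible way to follow this advice is to pass the value as an explicit argument, so that TensorFlow sees every change to it. This is a sketch under that assumption, not the only fix; the parameter name value is chosen here for illustration.

```python
import tensorflow as tf

@tf.function
def addition(value):
    # `value` is an explicit input of the function, not a captured global.
    return 1 + value

# Passing tensors of the same dtype and shape lets both calls share one trace
# while still using the current value on each call.
first = addition(tf.constant(4))
second = addition(tf.constant(10))

print(int(first), int(second))  # 5 11
```

Passing plain Python integers would also give the right results, but each new Python value would trigger a separate trace, so tensors are used here.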
Exceptions
This rule will not raise an issue if the global or nonlocal variable is a tensorflow.Variable.
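import tensorflow as tf

@tf.function
def addition():
    return 1 + foo

foo = tf.Variable(4)
addition()  # tf.Tensor(5, shape=(), dtype=int32)
foo.assign(10)
addition()  # tf.Tensor(11, shape=(), dtype=int32): the updated value is used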
In this case the ConcreteFunction does not capture a fixed value: it references the tensorflow.Variable itself, so each call reads the variable's current value and returns the expected result.